{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'\\n    @title: NodeFlow and Sampling\\n    @name: King\\n    @url: \\n        1、https://docs.dgl.ai/tutorials/models/5_giant_graph/1_sampling_mx.html\\n        2、https://archwalker.github.io/blog/2019/07/07/GNN-Framework-DGL-NodeFlow.html\\n    @decs: GraphSAGE\\n    \\n'"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "'''\n",
    "    @title: NodeFlow and Sampling\n",
    "    @name: King\n",
    "    @url: \n",
    "        1、GNN 教程：DGL框架-子图和采样:https://archwalker.github.io/blog/2019/07/07/GNN-Framework-DGL-NodeFlow.html\n",
    "        2、https://docs.dgl.ai/tutorials/models/5_giant_graph/1_sampling_mx.html\n",
    "    @desc: GraphSAGE\n",
    "    \n",
    "'''"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 一、GCN\n",
    "\n",
    "In this tutorial, we will run GCN on the Reddit dataset constructed by Hamilton et al., in which the nodes are posts and an edge connects two posts when the same user has commented on both. The task is to predict the category that a post belongs to. The graph has 233K nodes, 114.6M edges, and 41 categories. Let’s first load the Reddit graph."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "D:\\progrom\\python\\python\\python3\\lib\\site-packages\\h5py\\__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
      "  from ._conv import register_converters as _register_converters\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Finished data loading.\n",
      "  NumNodes: 232965\n",
      "  NumEdges: 114848857\n",
      "  NumFeats: 602\n",
      "  NumClasses: 41\n",
      "  NumTrainingSamples: 153431\n",
      "  NumValidationSamples: 23831\n",
      "  NumTestSamples: 55703\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "# The arrays in this notebook are MXNet NDArrays, so DGL must run on its MXNet\n",
    "# backend. DGL picks the backend when it is first imported, which is why the\n",
    "# variable is set before `import dgl`. With the PyTorch backend active,\n",
    "# update_all fails with: 'NDArray' object has no attribute 'device'.\n",
    "# If dgl was already imported in this kernel, restart the kernel first.\n",
    "os.environ['DGLBACKEND'] = 'mxnet'\n",
    "\n",
    "import numpy as np\n",
    "import dgl\n",
    "import dgl.function as fn\n",
    "from dgl import DGLGraph\n",
    "from dgl.data import RedditDataset\n",
    "import mxnet as mx\n",
    "from mxnet import gluon\n",
    "\n",
    "# Load the Reddit dataset with self-loop edges included.\n",
    "data = RedditDataset(self_loop=True)\n",
    "\n",
    "# Indices of the training nodes, as int64 (the non-zero entries of train_mask).\n",
    "train_nid = mx.nd.array(np.nonzero(data.train_mask)[0]).astype(np.int64)\n",
    "\n",
    "# Node features and labels as MXNet arrays, plus the sizes derived from them.\n",
    "features = mx.nd.array(data.features)\n",
    "in_feats = features.shape[1]\n",
    "labels = mx.nd.array(data.labels)\n",
    "n_classes = data.num_labels\n",
    "\n",
    "# Build a read-only DGLGraph and attach the features as node data.\n",
    "g = DGLGraph(data.graph, readonly=True)\n",
    "g.ndata['features'] = features"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here we define the node UDF which has a fully-connected layer:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class NodeUpdate(gluon.Block):\n",
    "    '''Node UDF: a dense layer followed by an optional activation.'''\n",
    "\n",
    "    def __init__(self, in_feats, out_feats, activation=None):\n",
    "        '''Create the dense layer mapping in_feats inputs to out_feats outputs.\n",
    "\n",
    "        activation is an optional callable applied to the dense output;\n",
    "        when it is falsy (e.g. None) the dense output is returned as is.\n",
    "        '''\n",
    "        super().__init__()\n",
    "        self.dense = gluon.nn.Dense(out_feats, in_units=in_feats)\n",
    "        self.activation = activation\n",
    "\n",
    "    def forward(self, node):\n",
    "        '''Transform the node field 'h' and return it under the key 'activation'.'''\n",
    "        projected = self.dense(node.data['h'])\n",
    "        if not self.activation:\n",
    "            return {'activation': projected}\n",
    "        return {'activation': self.activation(projected)}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In DGL, we implement GCN on the full graph with update_all in DGLGraph. The following code performs two-layer GCN on the Reddit graph."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "D:\\progrom\\python\\python\\python3\\lib\\site-packages\\dgl\\base.py:18: UserWarning: Initializer is not set. Use zero initializer instead. To suppress this warning, use `set_initializer` to explicitly specify which initializer to use.\n",
      "  warnings.warn(msg)\n"
     ]
    },
    {
     "ename": "AttributeError",
     "evalue": "'NDArray' object has no attribute 'device'",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mAttributeError\u001b[0m                            Traceback (most recent call last)",
      "\u001b[1;32m<ipython-input-3-7b8e78608891>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m     14\u001b[0m     g.update_all(message_func=fn.copy_src(src='h', out='m'),\n\u001b[0;32m     15\u001b[0m                  \u001b[0mreduce_func\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mfn\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msum\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;34m'm'\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mout\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;34m'h'\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 16\u001b[1;33m                  apply_node_func=lambda node: {'h': layers[i](node)['activation']})\n\u001b[0m\u001b[0;32m     17\u001b[0m     \u001b[0mh\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mg\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mndata\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mpop\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m'h'\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\progrom\\python\\python\\python3\\lib\\site-packages\\dgl\\graph.py\u001b[0m in \u001b[0;36mupdate_all\u001b[1;34m(self, message_func, reduce_func, apply_node_func)\u001b[0m\n\u001b[0;32m   2549\u001b[0m                                           \u001b[0mreduce_func\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mreduce_func\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   2550\u001b[0m                                           apply_func=apply_node_func)\n\u001b[1;32m-> 2551\u001b[1;33m             \u001b[0mRuntime\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mprog\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m   2552\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   2553\u001b[0m     def prop_nodes(self,\n",
      "\u001b[1;32mD:\\progrom\\python\\python\\python3\\lib\\site-packages\\dgl\\runtime\\runtime.py\u001b[0m in \u001b[0;36mrun\u001b[1;34m(prog)\u001b[0m\n\u001b[0;32m      8\u001b[0m         \u001b[1;32mfor\u001b[0m \u001b[0mexe\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mprog\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mexecs\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m      9\u001b[0m             \u001b[1;31m#prog.pprint_exe(exe)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 10\u001b[1;33m             \u001b[0mexe\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[1;32mD:\\progrom\\python\\python\\python3\\lib\\site-packages\\dgl\\runtime\\ir\\executor.py\u001b[0m in \u001b[0;36mrun\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m    453\u001b[0m         \u001b[0mspA_ctx_fn\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mspA\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    454\u001b[0m         \u001b[0mB\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mB\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 455\u001b[1;33m         \u001b[0mctx\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mF\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcontext\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mB\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    456\u001b[0m         \u001b[0mspA\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mspA_ctx_fn\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mctx\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    457\u001b[0m         \u001b[1;32mif\u001b[0m \u001b[0mF\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mndim\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mB\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m==\u001b[0m \u001b[1;36m1\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\progrom\\python\\python\\python3\\lib\\site-packages\\dgl\\backend\\pytorch\\tensor.py\u001b[0m in \u001b[0;36mcontext\u001b[1;34m(input)\u001b[0m\n\u001b[0;32m     67\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     68\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mcontext\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 69\u001b[1;33m     \u001b[1;32mreturn\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     70\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     71\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mastype\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mty\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;31mAttributeError\u001b[0m: 'NDArray' object has no attribute 'device'"
     ]
    }
   ],
   "source": [
    "L = 2          # number of GCN layers\n",
    "n_hidden = 64  # output units of each layer's fully connected part\n",
    "\n",
    "# One NodeUpdate per layer; the first one consumes the raw feature width.\n",
    "feature_width = g.ndata['features'].shape[1]\n",
    "layers = [\n",
    "    NodeUpdate(feature_width, n_hidden, mx.nd.relu),\n",
    "    NodeUpdate(n_hidden, n_hidden, mx.nd.relu),\n",
    "]\n",
    "for layer in layers:\n",
    "    layer.initialize()\n",
    "\n",
    "\n",
    "def apply_layer(node):\n",
    "    '''Run the layer selected by the loop index i and store its result as 'h'.'''\n",
    "    return {'h': layers[i](node)['activation']}\n",
    "\n",
    "\n",
    "h = g.ndata['features']\n",
    "for i in range(L):\n",
    "    # Stage the current representation on the graph, propagate it over all\n",
    "    # edges (copy source 'h' as message 'm', sum messages into 'h'), then take\n",
    "    # the updated representation back out of the node data.\n",
    "    g.ndata['h'] = h\n",
    "    g.update_all(\n",
    "        message_func=fn.copy_src(src='h', out='m'),\n",
    "        reduce_func=fn.sum(msg='m', out='h'),\n",
    "        apply_node_func=apply_layer,\n",
    "    )\n",
    "    h = g.ndata.pop('h')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
